from glob import glob

from natsort import natsorted
from torch.nn import BCELoss
from torch.optim import SGD, Adam
from dn_lr.dn.utils import import_from_pickle_file
import torch


def get_model_scratch(model_directory):
    """
    Load the from-scratch (pickled) models whose paths start with the given prefix.

    :param model_directory: path prefix of the model files; ``*`` is appended and the
        result is globbed, so a directory must be passed with its trailing separator
    :return: list of models where position ``i`` holds the model loaded from the file
        whose name ends in ``_i``
    :raises AssertionError: if the naturally-sorted files are not numbered 0, 1, 2, ...
    :raises ValueError: if a file name does not end in ``_<integer>``
    """
    # Kept from the original code although the name is unused here; presumably it makes
    # sure the model class's module is loaded before unpickling — confirm before removing.
    from dn_lr.dn.model_class_scratch_lr.logistic_regression.model import LogisticRegression  # noqa: F401

    # Glob once so the file list and the model count can never disagree.
    files = natsorted(glob(f'{model_directory}*'))
    models = []
    for true_label_index, file in enumerate(files):
        true_label_for_this_model = int(file.split("_")[-1])
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if true_label_for_this_model != true_label_index:
            raise AssertionError(f"Incorrect model index {true_label_index}, {true_label_for_this_model}")
        # NOTE(review): unpickling executes arbitrary code; only load trusted model files.
        models.append(import_from_pickle_file(file))
    return models


def get_cuda_status_as_device():
    """
    Choose the torch device to run on.

    :return: ``torch.device("cuda")`` when CUDA is available, otherwise ``torch.device("cpu")``
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def get_model_pytorch_for_inference(model_directory):
    """
    Load saved PyTorch models (and their raw parameters) whose paths start with the given prefix.

    :param model_directory: path prefix of the model files; ``*`` is appended and the
        result is globbed, so a directory must be passed with its trailing separator
    :return: tuple ``(models, model_params)``. ``models[i]`` is the model loaded from the
        file whose name ends in ``_i``, with tensors on the CPU. ``model_params[i]`` is
        ``{"weights": ..., "bias": ...}`` as produced by ``get_weight_and_bias_from_torch``.
    :raises AssertionError: if the naturally-sorted files are not numbered 0, 1, 2, ...
    :raises ValueError: if a file name does not end in ``_<integer>``
    """
    cpu = torch.device("cpu")
    models = []
    model_params = []
    # Glob once so the file list and the model count can never disagree.
    for true_label_index, file in enumerate(natsorted(glob(f'{model_directory}*'))):
        true_label_for_this_model = int(file.split("_")[-1])
        # Same consistency check as get_model_scratch: position i must hold model "_i".
        if true_label_for_this_model != true_label_index:
            raise AssertionError(f"Incorrect model index {true_label_index}, {true_label_for_this_model}")
        # Always deserialize onto the CPU. The previous code skipped map_location exactly
        # when no GPU was present, so GPU-saved checkpoints failed to load on CPU-only hosts.
        # NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects pickled
        # full-model objects; confirm the pinned torch version if loading starts failing.
        this_model = torch.load(file, map_location=cpu)
        models.append(this_model)
        weight, bias = get_weight_and_bias_from_torch(this_model)
        model_params.append({"weights": weight, "bias": bias})
    return models, model_params


def convert_model_scratch_to_pytorch_model(models_scratch, num_classes, device):
    """
    Copy each from-scratch model's learned parameters into a new ``LogisticRegressionModel``.

    :param models_scratch: iterable of objects exposing ``weights`` and ``bias`` attributes
    :param num_classes: number of classes; each model's input size is ``2 * num_classes - 1``
    :param device: device every new model is moved to and initialized on
    :return: list of PyTorch models, in the same order as ``models_scratch``
    """
    from dn_lr.dn.model_class.logistic_regression_model import LogisticRegressionModel

    in_features = 2 * num_classes - 1
    out_features = 1

    def _to_torch(scratch):
        # Build a fresh single-output model and load the scratch parameters into it.
        w, b = scratch.weights, scratch.bias
        torch_model = LogisticRegressionModel(in_features, out_features).to(device)
        torch_model.init_weights_and_bias(w, b, device)
        return torch_model

    return [_to_torch(scratch) for scratch in models_scratch]


def initialize_all_models(num_classes, device, LEARNING_RATE, pretrained=False, model_directory=None):
    """
    Create the per-class logistic-regression models, an SGD optimizer for each, and the loss.

    :param num_classes: number of classes; one model per class is created when not pretrained,
        and the model input size is ``2 * num_classes - 1``
    :param device: device the models are moved to
    :param LEARNING_RATE: learning rate used by every SGD optimizer
    :param pretrained: if True, load from-scratch models from ``model_directory`` and
        convert them to PyTorch models instead of creating fresh ones
    :param model_directory: path prefix of the pretrained model files; required when
        ``pretrained`` is True
    :return: tuple ``(all_models, optimizers, criterion)``; ``optimizers[i]`` optimizes
        ``all_models[i]`` and ``criterion`` is a single shared ``BCELoss``
    :raises ValueError: if ``pretrained`` is True and ``model_directory`` is None
    """
    from dn_lr.dn.model_class.logistic_regression_model import LogisticRegressionModel
    criterion = BCELoss()
    if pretrained:
        # Without this check the prefix becomes the literal "None*", which silently
        # globs the working directory and usually yields zero models.
        if model_directory is None:
            raise ValueError("model_directory is required when pretrained=True")
        all_models_from_scratch = get_model_scratch(model_directory)
        all_models = convert_model_scratch_to_pytorch_model(all_models_from_scratch, num_classes, device)
    else:
        input_dim = num_classes * 2 - 1
        output_dim = 1
        all_models = [LogisticRegressionModel(input_dim, output_dim).to(device) for _ in range(num_classes)]
    optimizers = [SGD(each_model.parameters(), lr=LEARNING_RATE) for each_model in all_models]
    return all_models, optimizers, criterion


def get_weight_and_bias_from_torch(model):
    """
    Extract the parameters of ``model.linear`` as NumPy values on the CPU.

    :param model: object exposing a ``linear`` layer with ``weight`` and ``bias`` attributes
    :return: tuple ``(weight, bias)``. ``weight`` is the flattened weight array. ``bias``
        is the first element of the flattened bias (a NumPy scalar), the empty array when
        the bias has no elements, or ``None`` when the layer has no bias.
    """
    layer = model.linear
    weight = layer.weight.detach().cpu().numpy().ravel()
    if layer.bias is None:
        # e.g. nn.Linear(..., bias=False); previously this crashed with AttributeError.
        return weight, None
    bias_values = layer.bias.detach().cpu().numpy().ravel()
    # Only the first bias is kept, which matches the single-output models used in this
    # file; an empty bias is returned as-is (this replaces the former bare `except:`).
    bias = bias_values[0] if bias_values.size else bias_values
    return weight, bias
